{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from secretnote.formal.locations import (\n",
    "    SPUFieldType,\n",
    "    SPUProtocolKind,\n",
    "    SymbolicPYU,\n",
    "    SymbolicSPU,\n",
    "    SymbolicWorld,\n",
    ")\n",
    "\n",
    "sym_world = SymbolicWorld(world=frozenset((\"alice\", \"bob\")))\n",
    "sym_alice = SymbolicPYU(\"alice\")\n",
    "sym_bob = SymbolicPYU(\"bob\")\n",
    "sym_spu = SymbolicSPU(\n",
    "    world=frozenset((\"alice\", \"bob\")),\n",
    "    protocol=SPUProtocolKind.SEMI2K,\n",
    "    field=SPUFieldType.FM128,\n",
    "    fxp_fraction_bits=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ray\n",
    "from secretnote.formal.locations import SFConfigSimulation, PortBinding\n",
    "\n",
    "ray.shutdown()\n",
    "\n",
    "sym_world.reify(SFConfigSimulation())\n",
    "\n",
    "alice = sym_alice.reify()\n",
    "bob = sym_bob.reify()\n",
    "spu = sym_spu.reify(\n",
    "    alice=PortBinding(announced_as=\"127.0.0.1:32767\"),\n",
    "    bob=PortBinding(announced_as=\"127.0.0.1:32768\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from secretnote.instrumentation.sdk import create_profiler, setup_tracing\n",
    "\n",
    "setup_tracing()\n",
    "\n",
    "profiler = create_profiler()\n",
    "profiler.start()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Yao's Millionaires' Problem in SecretFlow\"\"\"\n",
    "\n",
    "import jax\n",
    "import secretflow\n",
    "\n",
    "key = jax.random.PRNGKey(42)\n",
    "\n",
    "\n",
    "def make_money(seed: jax.random.KeyArray, generation: int) -> jax.Array:\n",
    "    for _ in range(generation):\n",
    "        seed, subkey = jax.random.split(seed)\n",
    "    return jax.random.randint(seed, shape=(), minval=10**6, maxval=10**9)\n",
    "\n",
    "\n",
    "def compare(a: jax.Array, b: jax.Array) -> jax.Array:\n",
    "    return a > b\n",
    "\n",
    "\n",
    "balance_alice = alice(make_money)(key, 3)\n",
    "balance_bob = bob(make_money)(key, 2)\n",
    "\n",
    "balance_alice = balance_alice.to(spu)\n",
    "balance_bob = balance_bob.to(spu)\n",
    "\n",
    "alice_is_richer = spu(compare)(balance_alice, balance_bob)\n",
    "\n",
    "alice_is_richer = secretflow.reveal(alice_is_richer)\n",
    "print(f\"{alice_is_richer=}\")\n",
    "\n",
    "account_alice, account_bob = secretflow.reveal((balance_alice, balance_bob))\n",
    "\n",
    "print(f\"{account_alice=}\")\n",
    "print(f\"{account_bob=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from secretnote.display import visualize_run\n",
    "\n",
    "profiler.stop()\n",
    "visualize_run(profiler)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
